In [ ]:
from pathlib import Path
from typing import List, Optional, Union

import circuitsvis as cv
import einops
import huggingface_hub as hf
import plotly.graph_objects as go
import torch as t
from datasets import load_dataset
from IPython.display import HTML, display
from jaxtyping import Float
from plotly.subplots import make_subplots
from transformer_lens import ActivationCache

from othello_gpt.data.vis import move_id_to_text, plot_game, plot_in_basis
from othello_gpt.util import (
    get_all_squares,
    load_model,
    load_probes,
    vocab_to_board,
)
In [ ]:
# Resolve the repository root (three levels up from this notebook's cwd)
# and the directories holding the dataset artifacts and trained probes.
root_dir = Path().cwd().parent.parent.parent
data_dir = root_dir / "data"
probe_dir = data_dir / "probes"

# hf.login((root_dir / "secret.txt").read_text())
dataset_dict = load_dataset("awonga/othello-gpt")

# Pick the best available accelerator: Apple MPS, then CUDA, else CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

# Board edge length; the game is played on a 6x6 Othello board.
size = 6
all_squares = get_all_squares(size)
Resolving data files:   0%|          | 0/87 [00:00<?, ?it/s]
Resolving data files:   0%|          | 0/87 [00:00<?, ?it/s]
Loading dataset shards:   0%|          | 0/87 [00:00<?, ?it/s]
In [ ]:
model = load_model(device, "awonga/othello-gpt-2M")
# Pull architecture dimensions off the config for use in later cells.
n_layer = model.cfg.n_layers
n_head = model.cfg.n_heads
d_head = model.cfg.d_head
d_model = model.cfg.d_model
# Read the MLP width from the config instead of hard-coding the 4x
# convention (d_mlp defaults to 4 * d_model in TransformerLens, but the
# config is authoritative if the model was trained with a different ratio).
n_neuron = model.cfg.d_mlp
number of parameters: 1.58M
In [ ]:
# Keep a small held-out slice for qualitative inspection.
n_test = 100
test_dataset = dataset_dict["test"].take(n_test)

# Detached unembedding / embedding / positional matrices, used by the
# probe loader as reference directions.
w_u = model.W_U.detach()
w_e = model.W_E.T.detach()
w_p = model.W_pos.T.detach()

probes = load_probes(
    probe_dir,
    device,
    w_u=w_u,
    w_e=w_e,
    w_p=w_p,
    # combos=["t+m", "t-m", "t-e", "t-pt", "m-pm"],
    combos=["+pee-ee"],
)
{name: probe.shape for name, probe in probes.items()}  # d_model (row col) n_probe_layer
Out[ ]:
{'ptm': torch.Size([128, 36, 17]),
 'tm': torch.Size([128, 36, 17]),
 'ee': torch.Size([128, 36, 17]),
 'le': torch.Size([128, 36, 17]),
 'pee': torch.Size([128, 36, 17]),
 'tnpt': torch.Size([128, 36, 17]),
 'u': torch.Size([128, 36, 17]),
 'b': torch.Size([128, 36, 17]),
 'p': torch.Size([128, 31, 17]),
 '+pee-ee': torch.Size([128, 36, 17])}
In [ ]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[t.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: t.Tensor,
    title: str = "",
    max_width: int = 700,
) -> str:
    """Render circuitsvis attention heatmaps for the given heads as raw HTML.

    Args:
        heads: Flat head indices (``layer * n_heads + head``), either a single
            int, a list of ints, or a tensor of indices.
        local_cache: Activation cache holding the ``"attn"`` patterns for a
            single-game forward pass.
        local_tokens: Move ids for the game; converted to strings for the
            axis labels.
        title: Heading rendered above the plot.
        max_width: Maximum width of the wrapping ``<div>``, in pixels.

    Returns:
        An HTML string containing the title and the circuitsvis plot.

    NOTE(review): reads the notebook-level globals ``model`` and ``size``.
    """
    # Normalise a single head index to a list so iteration is uniform.
    if isinstance(heads, int):
        heads = [heads]

    labels: List[str] = []
    head_patterns: List[Float[t.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item.
    batch_index = 0

    for head in heads:
        # Decode the flat index into (layer, head) and label it "L{l}H{h}".
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        head_patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Token strings for the axis labels. The loop variable is deliberately
    # not named `t`, which would shadow the torch alias used just below.
    str_tokens = [move_id_to_text(move_id, size) for move_id in local_tokens]

    # Combine the per-head patterns into a single [head, dest, src] tensor.
    patterns: Float[t.Tensor, "head_index dest_pos src_pos"] = t.stack(
        head_patterns, dim=0
    ).cpu()

    # Each causal row i sums to 1 over its i+1 visible positions; rescale by
    # (i + 1) so later rows don't look diluted relative to early ones.
    patterns *= (t.arange(patterns.shape[1]) + 1).unsqueeze(0).unsqueeze(-1)

    # Circuitsvis plot (note we get the code version so we can concatenate with the title)
    plot = cv.circuitsvis.attention.attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"
In [ ]:
# Visualise attention patterns for the first three test games.
for i in range(3):
    test_game = test_dataset[i]
    test_input_ids = t.tensor(test_game["input_ids"], device=device)
    # Drop the final token: the model predicts the *next* move at each position.
    test_logits, test_cache = model.run_with_cache(test_input_ids[:-1])
    # Show every head in the model (flat indices 0 .. n_layers * n_heads - 1).
    vis = visualize_attention_patterns(
        list(range(model.cfg.n_layers * model.cfg.n_heads)),
        test_cache,
        test_game["moves"],
    )
    display(HTML(vis))
    plot_game(test_game)